{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "21eef928-9a7c-447e-b1cb-a81330027c3c",
   "metadata": {},
   "source": [
    "# Details about Nvidia Merlin"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1ddbe18-a511-4182-80b5-8df0a847727c",
   "metadata": {},
   "source": [
    "Documentation links:\n",
    "\n",
    "* General documentation: https://nvidia-merlin.github.io/Merlin/stable/README.html\n",
    "* GitHub: https://github.com/NVIDIA-Merlin/dataloader\n",
    "* Docker containers: https://nvidia-merlin.github.io/Merlin/stable/containers.html\n",
    "\n",
    "\n",
    "Installation via pip:\n",
    "1. Install cudf + dask-cudf: `python -m pip install cudf-cu11==23.08 rmm-cu11==23.08 dask-cudf-cu11==23.08 --extra-index-url https://pypi.nvidia.com/`\n",
    "2. Install merlin dataloader: `python -m pip install merlin-dataloader`\n"
   ]
  },
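  {
   "cell_type": "markdown",
   "id": "merlin-install-check-note",
   "metadata": {},
   "source": [
    "To confirm that the installation worked, a minimal sketch like the one below can check whether the relevant modules can be located. It assumes that `cudf`, `dask_cudf` and `merlin.dataloader` are the import names of the packages installed above; `is_installed` is a helper introduced here for illustration only and is not part of Merlin."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "merlin-install-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "from importlib.util import find_spec\n",
    "\n",
    "\n",
    "def is_installed(module_name):\n",
    "    # find_spec imports the parent package of a dotted name, which can fail\n",
    "    # if the parent is missing or broken -> treat such a failure as 'not installed'\n",
    "    try:\n",
    "        return find_spec(module_name) is not None\n",
    "    except (ImportError, ValueError):\n",
    "        return False\n",
    "\n",
    "\n",
    "for name in ('cudf', 'dask_cudf', 'merlin.dataloader'):\n",
    "    print(name, 'found' if is_installed(name) else 'missing')"
   ]
  },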
  {
   "cell_type": "markdown",
   "id": "abd58faa-5248-4ee8-aaf9-0273ba21baf3",
   "metadata": {},
   "source": [
    "# Details about data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27bf368e-f476-45e4-b58d-332164d9c779",
   "metadata": {},
   "source": [
    "The parquet files have the folowing columns:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0553ff7a-56bd-4027-adbb-0002691274e9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n",
      "  warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n",
      "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from merlin.dtypes import boolean, float32, int64\n",
    "\n",
    "\n",
    "PARQUET_SCHEMA = {\n",
    "    'X': float32, # -> gene expression values (normalized to 10.000 counts per cell + log1p transformed)\n",
    "    'soma_joinid': int64,  # soma_joinid from CELLxGENE\n",
    "    'is_primary_data': boolean,  # binary indicator whether data is primary data or not (currently all data is primary data)\n",
    "    'dataset_id': int64,  # name of the associated data set\n",
    "    'donor_id': int64,  # name of the donor (caution! This might not be unique across datasets -> use tech_sample column instead)\n",
    "    'assay': int64,  # name of the used assay\n",
    "    'cell_type': int64,  # cell type label\n",
    "    'development_stage': int64,  # development stage label\n",
    "    'disease': int64,  # disease state label\n",
    "    'tissue': int64,  # specfic tissue label\n",
    "    'tissue_general': int64,  # general tissue label\n",
    "    'tech_sample': int64,  # batch indicator \n",
    "    'idx': int64,  # consecutive enumeration of all cells in the train, val and test data\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe2671f8-1f32-4337-b4ae-b23f3123e703",
   "metadata": {},
   "source": [
    "All categorical meta data (['dataset_id', 'donor_id', 'assay', 'cell_type', 'development_stage', 'disease', 'tissue', 'tissue_general', 'tech_sample']) are encoded as integers. \n",
    "\n",
    "The lookup tables to map the integer labels to their corresponding string labels can be found under: `join(DATA_PATH, categorical_lookup)`\n",
    "\n",
    "E.g. the mapping for the `cell_type` column can be found in the `cell_type.parquet` file."
   ]
  },
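  {
   "cell_type": "markdown",
   "id": "decode-labels-note",
   "metadata": {},
   "source": [
    "The lookup described above can be sketched as follows. This is a hedged example: it assumes that each lookup table is indexed by the integer code and stores the string labels in a column named `label` (check the actual files before relying on this). `decode_labels` and the toy table are introduced here for illustration; with the real data, the table would be loaded with `pd.read_parquet(join(DATA_PATH, 'categorical_lookup', 'cell_type.parquet'))`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "decode-labels-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "def decode_labels(codes, lookup, label_column='label'):\n",
    "    # select rows by integer code (assumed to be the index) and return the string labels\n",
    "    # for a batch tensor, pass e.g. batch['cell_type'].cpu().tolist()\n",
    "    return lookup.loc[list(codes), label_column].tolist()\n",
    "\n",
    "\n",
    "# toy lookup table standing in for cell_type.parquet (labels are made up)\n",
    "toy_lookup = pd.DataFrame({'label': ['B cell', 'T cell', 'neuron']})\n",
    "print(decode_labels([2, 0, 0], toy_lookup))"
   ]
  },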
  {
   "cell_type": "markdown",
   "id": "8f0be050-9c18-489f-8fef-31d402ba288c",
   "metadata": {},
   "source": [
    "# Use with PyTorch Lightning DataModule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e6c007e2-5ca5-4eb4-a717-354626b73a68",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from cellnet.datamodules import MerlinDataModule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "53c267d4-c197-4d13-a3ab-28a49c76e3ad",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# path to merlin store\n",
    "DATA_PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7d74cb9b-3dc5-497d-b8e1-352b16d3a437",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "datamodule = MerlinDataModule(\n",
    "    path=DATA_PATH,\n",
    "    columns=['cell_type'],\n",
    "    batch_size=2048,\n",
    "    sub_sample_frac=1., # randomly subsample data (can be between (0., 1.])\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "89058ebb-62b4-4919-8d31-9850ee58656a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "cell_type: tensor([  7, 127, 152,  ...,  22, 127,   4], device='cuda:0')\n",
      "X: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "cell_type: tensor([  9, 132, 118,  ...,  14, 129, 127], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "import gc\n",
    "\n",
    "\n",
    "# get dataloaders for train, valiation and test set\n",
    "train_loader = datamodule.train_dataloader()\n",
    "val_loader = datamodule.val_dataloader()\n",
    "test_loader = datamodule.test_dataloader()\n",
    "\n",
    "\n",
    "# how to use dataloaders\n",
    "for ix, (batch, _) in enumerate(train_loader):\n",
    "    # put your training code here:\n",
    "    print('X:', batch['X'])\n",
    "    print('cell_type:', batch['cell_type'])\n",
    "\n",
    "    # Merlin tends to use a lot of GPU memory if the garbage collection isn't called regularly\n",
    "    # -> manually call python garbage collection every 10 steps \n",
    "    if ix % 10 == 0:\n",
    "        gc.collect()\n",
    "\n",
    "    # don't iterate over all traning data for this tutorial\n",
    "    if ix == 1:\n",
    "        break\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0abb0c02-7f7f-4521-b866-45c6fa3e395d",
   "metadata": {},
   "source": [
    "# Use as standalone PyTorch DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b308cc65-a2dd-4b37-9c2c-0c324c9825d5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from os.path import join\n",
    "\n",
    "from cellnet.datamodules import merlin_dataset_factory, set_default_kwargs_dataset\n",
    "from merlin.dataloader.torch import Loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "96d04467-242d-4beb-b89b-ecf7c964d3ad",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# path to merlin store\n",
    "DATA_PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4732b0e7-ad05-43d7-b817-65cad0fca24e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X: tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')\n",
      "cell_type: tensor([46, 44, 38,  ..., 67, 19, 60], device='cuda:0')\n",
      "X: tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "        ...,\n",
      "        [0.0000, 0.0000, 0.0000,  ..., 1.4717, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],\n",
      "       device='cuda:0')\n",
      "cell_type: tensor([ 14,  64, 131,  ...,  60,  63, 132], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# manually create data loaders for train and validation set\n",
    "train_dataset = merlin_dataset_factory(\n",
    "    join(DATA_PATH, 'train'), \n",
    "    columns=['cell_type'], \n",
    "    dataset_kwargs=set_default_kwargs_dataset(training=True)\n",
    ")\n",
    "train_loader = Loader(train_dataset, batch_size=2048, shuffle=True)\n",
    "\n",
    "\n",
    "val_dataset = merlin_dataset_factory(\n",
    "    join(DATA_PATH, 'val'), \n",
    "    columns=['cell_type'], \n",
    "    dataset_kwargs=set_default_kwargs_dataset(training=False)\n",
    ")\n",
    "val_loader = Loader(val_dataset, batch_size=2048, shuffle=False)\n",
    "\n",
    "\n",
    "# how to use dataloaders\n",
    "for ix, (batch, _) in enumerate(train_loader):\n",
    "    # put your training code here:\n",
    "    print('X:', batch['X'])\n",
    "    print('cell_type:', batch['cell_type'])\n",
    "\n",
    "    # Merlin tends to use a lot of GPU memory if the garbage collection isn't called regularly\n",
    "    # -> manually call python garbage collection every 10 steps \n",
    "    if ix % 10 == 0:\n",
    "        gc.collect()\n",
    "\n",
    "    # don't iterate over all traning data for this tutorial\n",
    "    if ix == 1:\n",
    "        break\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85addfa5-f3da-4991-9222-1beb319be814",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
